library(LocationSmoothedBART)
library(fda)
library(tidyverse)
In this section, we generate random functional data points to simulate our predictors. Each functional predictor is created by adding normally distributed noise to a base function. This process mimics the variability and noise present in real-world data, such as medical images.
# Generate 2000 noisy curves for each class
set.seed(34)
X_1 = rnorm(10, sd=1)
X_2 = rnorm(10, sd=1)
X_3 = rnorm(10, sd=1)
X_4 = rnorm(10, sd=1)
X_5 = rnorm(10, sd=1)
c1_s_data = NULL
c2_s_data = NULL
c3_s_data = NULL
c4_s_data = NULL
c5_s_data = NULL
for(i in 1:2000){
c1_s_data = cbind(c1_s_data, X_1 + rnorm(10,sd=0.5) )
c2_s_data = cbind(c2_s_data, X_2 + rnorm(10,sd=0.5) )
c3_s_data = cbind(c3_s_data, X_3 + rnorm(10,sd=0.5) )
c4_s_data = cbind(c4_s_data, X_4 + rnorm(10,sd=0.5) )
c5_s_data = cbind(c5_s_data, X_5 + rnorm(10,sd=0.5) )
}
sim_data_1 = cbind(c1_s_data, c2_s_data, c3_s_data, c4_s_data, c5_s_data)
Here, we convert the generated data into functional data objects using B-spline basis functions. This step is essential for transforming discrete data points into smooth functions that the lsBART model can process. The B-spline basis allows us to represent the data smoothly and flexibly.
bb = create.bspline.basis(rangeval = c(0,1), norder=8)
############################################
# CREATE THE POINTS
d1 = Data2fd(c1_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d1_points = t(d1$coefs)%*%t(eval.basis(seq(by =0.01,from=0,to=1), bb ))
d2 = Data2fd(c2_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d2_points = t(d2$coefs)%*%t(eval.basis(seq(by =0.01,from=0,to=1), bb ))
d3 = Data2fd(c3_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d3_points = t(d3$coefs)%*%t(eval.basis(seq(by =0.01,from=0,to=1), bb ))
d4 = Data2fd(c4_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d4_points = t(d4$coefs)%*%t(eval.basis(seq(by =0.01,from=0,to=1), bb ))
d5 = Data2fd(c5_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d5_points = t(d5$coefs)%*%t(eval.basis(seq(by =0.01,from=0,to=1), bb ))
d1_points = d1_points+rnorm( dim(d1_points)[2], sd=0.05 )
d2_points = d2_points+rnorm( dim(d2_points)[2], sd=0.05 )
d3_points = d3_points+rnorm( dim(d3_points)[2], sd=0.05)
d4_points = d4_points+rnorm( dim(d4_points)[2], sd=0.05 )
d5_points = d5_points+rnorm( dim(d5_points)[2], sd=0.05)
Visualizing the functional predictors helps us understand the data structure. Here, we plot each of the five functional predictors to see their shapes and the added noise.
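One way to do this (a sketch; only the first 20 curves per predictor are drawn to keep the figure readable) is to evaluate each predictor on the same grid used above:

```r
t_grid = seq(0, 1, by = 0.01)
all_points = list(d1_points, d2_points, d3_points, d4_points, d5_points)
par(mfrow = c(2, 3))
for (p in 1:5) {
  # Each row of d*_points is one curve evaluated on t_grid
  matplot(t_grid, t(all_points[[p]][1:20, ]), type = "l", lty = 1,
          xlab = "t", ylab = paste0("X_", p, "(t)"),
          main = paste0("Functional predictor X_", p))
}
par(mfrow = c(1, 1))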
In this step, we define the true response variable from the functional predictors. The response depends only on specific intervals of the predictors, and Gaussian noise is added to simulate measurement error. Although we include code for the full modified Friedman function from the paper, this example builds the continuous y variable of interest from a partial Friedman function, the non-linear sine portion, using only the first two functional predictors. The remaining predictors are therefore pure noise, which lets us illustrate lsBART’s ability to find the useful predictors and their important regions.
t1 = 1:10
t2 = 21:30
t3 = 41:50
t4 = 61:70
t5 = 81:90
friedmanFunc <- function(x){
res = 10*sin(pi*x[1]*x[2]) + 20*(x[3]-0.5)^2 + 10*x[4] + 5*x[5]
return(res)
}
friedmanFunc2 <- function(points, ts){
res = 10*rowSums(sin(pi* points[[1]][, ts[[1]]] * points[[2]][, ts[[2]] ]) ) +
20*rowSums((points[[3]][, ts[[3]] ] -0.5)^2) +
10*rowSums((points[[4]][, ts[[4]] ]))+
5*rowSums((points[[5]][, ts[[5]] ]))
return(res)
}
partialfriedmanFunc3 <- function(points, ts){
res = 10*rowSums(sin(pi* points[[1]][, ts[[1]]] ) )+
20*rowSums((points[[2]][, ts[[2]] ] -0.5)^2)
return(res)
}
num_intvls = dim(d1_points)[2]
points1 = list(d1_points, d2_points)
ts1 = list(t1,t2)
#the truth
true_y = partialfriedmanFunc3(points1, ts1)
# Add Gaussian noise
true_y = true_y + rnorm(length(true_y))
num_predictors = 5
points = cbind(d1_points,
d2_points,
d3_points,
d4_points,
d5_points)
#create Train and Test sets
df = data.frame(points)
df = as_tibble(df)
#use 50% of the dataset as the training set and 50% as the test set
set.seed(1)
sample <- sample(c(TRUE, FALSE), nrow(df), replace=TRUE, prob=c(0.5,0.5))
train <- df[sample, ]
test <- df[!sample, ]
y_train = true_y[sample]
y_test = true_y[!sample]
Here, we display the training and test data to verify their structure and the response variable.
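A quick sketch for inspecting the split:

```r
# Inspect the split sizes and preview the responses
dim(train)
dim(test)
head(y_train)
head(y_test)
```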
The w_lsBART function fits the Bayesian additive regression trees model for continuous outcomes. Here we define the model as ‘m0’. The model is trained on the training data and then used to predict the outcomes on the test data.
m0 = w_lsBART(num_predictors = num_predictors,
num_intvls =num_intvls,
x.train = as.matrix(train),
y.train = y_train,
x.test = as.matrix(test),
sparse = TRUE,
ndpost = 100,
nskip = 100,
ntree = 200,
dart_fp = TRUE,
dart_int = TRUE)
#> *****Into main of wbart
#> *****Data:
#> data:n,p,np: 961, 505, 1039
#> y1,yn: 22.165046, -52.648046
#> x1,x[n*p]: -0.226485, 0.314343
#> xp1,xp[np*p]: 0.343495, 0.985763
#> *****Number of Trees: 200
#> *****Number of Cut Points: 100 ... 100
#> *****burn and ndpost: 100, 100
#> *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,10.952661,3.000000,830.632380
#> *****sigma: 65.300946
#> *****w (weights): 1.000000 ... 1.000000
#> *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 1,0,1,0.5,1,505,1
#> *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 100,100,100,100
#> *****printevery: 100
#> *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
#>
#> MCMC
#> done 0 (out of 200)
#> done 100 (out of 200)
#> time: 3s
#> check counts
#> trcnt,tecnt,temecnt,treedrawscnt: 100,100,100,100
We evaluate the model’s performance by calculating the root mean squared error (RMSE) of the predictions on the test set, and we plot the predicted vs. actual values to visualize the model’s accuracy. We call ‘yhat.test’, which returns a matrix of predicted test values, one row per posterior sample, and use colMeans() to get a point estimate for each test observation.
predictions <- colMeans(m0$yhat.test)
rmse <- sqrt(mean((y_test - predictions)^2))
# Plot predictions vs y_test
plot(y_test, predictions, main="Predictions vs Actual",
xlab="Actual y_test", ylab="Predictions", pch=19, col="blue")
abline(0, 1, col="red", lwd=2)
Here we see that the predictions track the actual values closely, falling near the 45-degree line, with a root mean squared error of 9.91.
Now we examine the probabilities from the trained lsBART model to identify which functional predictors were important to the regression task. Clearly, we see that lsBART captures the important functional predictors, X_1 and X_2, while indicating that the other functional predictors (X_3, X_4, and X_5) are irrelevant to the regression task.
par(mfrow=c(1,1))
barplot(colMeans(m0$varprob_fp), main="Functional Predictor Probabilities",
space=0, names.arg = paste0("X_", 1:num_predictors))
Now we highlight one of the best features of lsBART, the ability to interpret important functional predictors and regions of interest within them.
# Adjust plot margins and increase plot size
par(mfrow=c(2,1), mar=c(4, 4, 2, 1))
# Plot for c1_s_data
plot(Data2fd(c1_s_data, basis=bb))
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
#> [1] "done"
abline(v = (ts1[[1]]) / dim(d1_points)[2], lwd=1, col="green", lty=2)
barplot(colMeans(m0$varprob_times[[1]]), main="X_1 Location Probability",
names.arg = as.character(1:num_intvls), space=0)
abline(v = (ts1[[1]]), lwd=1, col="green", lty=2)
# Plot for c2_s_data
plot(Data2fd(c2_s_data, basis=bb))
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
#> [1] "done"
abline(v = (ts1[[2]]) / dim(d1_points)[2], lwd=1, col="green", lty=2)
barplot(colMeans(m0$varprob_times[[2]]), main="X_2 Location Probability",
names.arg = as.character(round((1:num_intvls) / num_intvls, 2)), space=0)
abline(v = (ts1[[2]]), lwd=1, col="green", lty=2)
source("TreeShap_Source.R", local = knitr::knit_global())
#>
#> Attaching package: 'data.table'
#> The following objects are masked from 'package:dplyr':
#>
#> between, first, last
#> The following object is masked from 'package:purrr':
#>
#> transpose
library(LocationSmoothedBART)
library(fda)
library(tidyverse)
library(pROC)
library(treeshap)
library(data.table)
In this section, we generate random functional data points to simulate our predictors. Each functional predictor is created by adding normally distributed noise to a base function. This process mimics the variability and noise present in real-world data, such as medical images.
# Generate 400 noisy curves for each class
set.seed(1)
X_1 = rnorm(10, sd=1)
X_2 = rnorm(10, sd=1)
X_3 = rnorm(10, sd=1)
X_4 = rnorm(10, sd=1)
c1_s_data = NULL
c2_s_data = NULL
c3_s_data = NULL
c4_s_data = NULL
for(i in 1:400){
c1_s_data = cbind(c1_s_data, X_1 + rnorm(10,sd=0.5) )
c2_s_data = cbind(c2_s_data, X_2 + rnorm(10,sd=0.5) )
c3_s_data = cbind(c3_s_data, X_3 + rnorm(10,sd=0.5) )
c4_s_data = cbind(c4_s_data, X_4 + rnorm(10,sd=0.5) )
}
sim_data_1 = cbind(c1_s_data, c2_s_data, c3_s_data, c4_s_data)
Here, we convert the generated data into functional data objects using B-spline basis functions. This step is essential for transforming discrete data points into smooth functions that the lsBART model can process. The B-spline basis allows us to represent the data smoothly and flexibly.
bb = create.bspline.basis(rangeval = c(0,1), norder=8)
############################################
# CREATE THE POINTS
p_by=0.01
d1 = Data2fd(c1_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d1_points = t(d1$coefs)%*%t(eval.basis(seq(by =p_by,from=0,to=1), bb ))
d2 = Data2fd(c2_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d2_points = t(d2$coefs)%*%t(eval.basis(seq(by =p_by,from=0,to=1), bb ))
d3 = Data2fd(c3_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d3_points = t(d3$coefs)%*%t(eval.basis(seq(by =p_by,from=0,to=1), bb ))
d4 = Data2fd(c4_s_data, basis=bb)
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
d4_points = t(d4$coefs)%*%t(eval.basis(seq(by =p_by,from=0,to=1), bb ))
d1_points = d1_points+rnorm( dim(d1_points)[2], sd=0.05 )
d2_points = d2_points+rnorm( dim(d2_points)[2], sd=0.05 )
d3_points = d3_points+rnorm( dim(d3_points)[2], sd=0.05)
d4_points = d4_points+rnorm( dim(d4_points)[2], sd=0.05 )
Visualizing the functional predictors helps us understand the data structure. Here, we plot each of the four functional predictors to see their shapes and the added noise.
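As in the continuous example, a sketch for plotting a handful of curves per predictor on the evaluation grid (only 20 curves each, to keep the figure readable) is:

```r
t_grid = seq(0, 1, by = p_by)
all_points = list(d1_points, d2_points, d3_points, d4_points)
par(mfrow = c(2, 2))
for (p in 1:4) {
  # Each row of d*_points is one curve evaluated on t_grid
  matplot(t_grid, t(all_points[[p]][1:20, ]), type = "l", lty = 1,
          xlab = "t", ylab = paste0("X_", p, "(t)"),
          main = paste0("Functional predictor X_", p))
}
par(mfrow = c(1, 1))
```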
In this step, we define the true response variable using the functional predictors. The response is calculated based on specific intervals of the predictors, and Gaussian noise is added to simulate measurement error. We then convert the continuous response variable into a binary variable based on whether it is above or below the mean.
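The code that builds the binary response and the train/test split is not shown here; a sketch reusing partialfriedmanFunc3 and the interval indices from the continuous example (the object names below are illustrative) might look like:

```r
# True continuous response from intervals of X_1 and X_2, plus noise
true_y = partialfriedmanFunc3(list(d1_points, d2_points), list(1:10, 21:30)) +
  rnorm(dim(d1_points)[1])
# Binarize: 1 if above the mean, 0 otherwise
y_binary = as.integer(true_y > mean(true_y))
# 50/50 train/test split, as before
df = as_tibble(data.frame(cbind(d1_points, d2_points, d3_points, d4_points)))
keep = sample(c(TRUE, FALSE), nrow(df), replace = TRUE)
train = df[keep, ]; test = df[!keep, ]
y_train = y_binary[keep]; y_test = y_binary[!keep]
```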
The p_lsBART function fits the Bayesian additive regression trees model for binary outcomes. The model here is defined as ‘m0’. The model parameters such as the number of trees (ntree) and the number of posterior samples (ndpost) are set. The model is trained on the training data and then used to predict the outcomes on the test data.
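The fitting call for m0 is not shown above. A minimal sketch, assuming p_lsBART mirrors w_lsBART's interface and that the train/test matrices and binary y_train are built as in the continuous example, would be:

```r
# Sketch only: argument names are assumed to match w_lsBART
m0 = p_lsBART(num_predictors = 4,
              num_intvls = dim(d1_points)[2],
              x.train = as.matrix(train),
              y.train = y_train,
              x.test = as.matrix(test),
              sparse = TRUE,
              ndpost = 100,
              nskip = 100,
              ntree = 200,
              dart_fp = TRUE,
              dart_int = TRUE)
```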
For binary classification, we call ‘prob.test.mean’ from our model m0 to get the posterior mean predicted probability for each test observation. We calculate the classification accuracy and the area under the curve (AUC) to evaluate the model’s performance. The accuracy measures the proportion of correctly classified instances, while the AUC provides a single metric to assess the model’s ability to distinguish between classes.
# Calculate predictions
predictions <- m0$prob.test.mean
# Calculate classification accuracy
predicted_classes <- ifelse(predictions > 0.5, 1, 0)
accuracy <- mean(predicted_classes == y_test)
print(paste("Classification Accuracy:", round(accuracy, 2)))
#> [1] "Classification Accuracy: 0.89"
# Plot ROC and calculate AUC
roc_obj <- roc(y_test, predictions)
#> Setting levels: control = 0, case = 1
#> Setting direction: controls < cases
auc_val <- auc(roc_obj)
print(paste("AUC:", round(auc_val, 2)))
#> [1] "AUC: 0.98"
# Plot AUC curve
plot(roc_obj, main = paste("Test Set AUC Curve (", round(auc_val, 2), ")" ) )
abline(0, 1, col = "red", lty = 2)
We see the model performs well, with a classification accuracy of 0.89 and an AUC value of 0.98.
Now we examine the probabilities from the trained lsBART model to identify which functional predictors were important to the classification task. Clearly, lsBART captures the important functional predictors, X_1 and X_2, while indicating that the other functional predictors (X_3 and X_4) are irrelevant to the task.
par(mfrow=c(1,1))
barplot(colMeans(m0$varprob_fp), main="Functional Predictor Probabilities",
space=0, names.arg = paste0("X_", 1:num_predictors))
Now we highlight one of the best features of lsBART, the ability to interpret important functional predictors and regions of interest within them.
# Define desired observations to look at with treeshap
obs_id = 2
# Adjust plot margins and increase plot size
par(mfrow=c(2,1), mar=c(4, 4, 2, 1))
# Plot for c1_s_data
plot(Data2fd(c1_s_data, basis=bb))
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
#> [1] "done"
abline(v = (ts1[[1]]) / dim(d1_points)[2], lwd=1, col="green", lty=2)
barplot(colMeans(m0$varprob_times[[1]]), main="X_1 Location Probability",
names.arg = as.character(1:num_intvls), space=0)
abline(v = (ts1[[1]]), lwd=2, col="green", lty=2)
# Plot for c2_s_data
plot(Data2fd(c2_s_data, basis=bb))
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 10 )
#> [1] "done"
abline(v = (ts1[[2]]) / dim(d1_points)[2], lwd=1, col="green", lty=2)
barplot(colMeans(m0$varprob_times[[2]]), main="X_2 Location Probability",
names.arg = as.character(round((1:num_intvls) / num_intvls, 2)), space=0)
abline(v = (ts1[[2]]), lwd=2, col="green", lty=2)
TreeShap is used to understand the contribution of each feature to the prediction of a single observation. Here, we apply the TreeShap algorithm to observation id: 2 from the test set to interpret why it was classified as 1.
# Get the names
var_names2 = c("X_1", "X_2", "X_3", "X_4" )
obs1= data.frame(test)[obs_id,]
unif_list2=NULL
# Quick estimate: use only every (ndpost/5)-th posterior draw
for (i in seq(5, ndpost, by = ndpost/5)) {
bart_unify = BART.unify(m0, data.frame(df), ntree, ndpost, i )
treeshap1 <- treeshap(bart_unify, obs1 , verbose = F, interactions = F)
unif_list2 = rbind(unif_list2, treeshap1$shaps)
print(round(i/ndpost,2)) #Percent complete
}
#> [1] 0.01
#> [1] 0.21
#> [1] 0.41
#> [1] 0.61
#> [1] 0.81
# Full estimate over all posterior draws (slow) -- run if you have patience
# for (i in 1:ndpost) {
#   bart_unify = BART.unify(m0, data.frame(df), ntree, ndpost, i)
#   treeshap1 <- treeshap(bart_unify, obs1, verbose = F, interactions = F)
#   unif_list2 = rbind(unif_list2, treeshap1$shaps)
#   print(round(i/ndpost, 2)) # Percent complete
# }
The following bar plot shows the average Shapley values for each functional predictor. This visualization helps identify which predictor functions had the largest influence on the prediction of the selected observation.
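The data frame df_shapely_total used in the next plot is not constructed in the code shown; one way to build it from unif_list2 (its name and columns are inferred from the plotting call, so treat this as a sketch) is:

```r
# Average the Shapley values across the kept posterior draws, then sum
# within each functional predictor to get one importance value per predictor
mean_shaps = colMeans(unif_list2)
df_shapely_total = tibble(
  name = rep(var_names2, each = num_intvls),
  shapely = as.numeric(mean_shaps)
) %>%
  group_by(name) %>%
  summarise(shapely = sum(shapely), .groups = "drop")
```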
ggplot(df_shapely_total, aes(x=name, y=shapely, fill = shapely>0)) +
geom_bar(stat="identity") +
theme_bw()+
coord_flip()+
ggtitle("Predictor Importance")+
ylab("")+
xlab("")+
scale_fill_manual(values = c("darkblue", "skyblue"),
name = "Sign",
breaks = c(TRUE, FALSE),
labels = c("Positive", "Negative"))+
theme(axis.text.x = element_text(angle = 0, hjust = 1))+
theme(plot.title = element_text(size = 20, face = "bold"),
axis.title.x = element_text(size = 18),
axis.title.y = element_text(size = 18),
legend.position = "none",
axis.text=element_text(size=14))
Here we see that X_2 had the largest influence on the prediction for observation 2, so we infer that X_2 was the most important predictor function for this observation.
Finally, we examine the Shapley values of observation 2’s X_2 predictor function to identify the specific regions within X_2 that were most influential in the classification.
i=2
plot_data= obs1[(num_intvls*(i-1)+1):(num_intvls*i)]
plot_points_imp = colMeans(unif_list2)[(num_intvls*(i-1)+1):(num_intvls*i)]
# rescale the Shapley values to the data's scale for visualization
plot_points_imp2 = mean(as.numeric(plot_data))*plot_points_imp/max(plot_points_imp)
# Smooth the observation's X_2 values into a functional data object
pd = Data2fd(as.numeric(plot_data))
#> 'y' is missing, using 'argvals'
#> 'argvals' is missing; using seq( 0 , 1 , length= 101 )
df_shape_loc11 = tibble( percentage = c(0, pd$basis$params, 1),
line = pd$coefs[c(-1, -length(pd$coefs))] ,
shapely = plot_points_imp
)
# df_shape_loc11
multiplier <- max(df_shape_loc11$line)/max(df_shape_loc11$shapely)
ggplot(df_shape_loc11, aes(x = percentage)) +
geom_bar(aes(y = shapely*multiplier, fill = shapely>0), stat = "identity", alpha = 0.7) +
geom_line(aes(y = line), color = "purple", linewidth = 1.5, alpha = 0.5) +
scale_fill_manual(values = c("darkblue", "skyblue"),
name = "Sign",
breaks = c(TRUE, FALSE),
labels = c("Positive", "Negative"))+
ggtitle("")+
ylab(" ")+
xlab("t")+
scale_y_continuous(name = "X_2(t)",
sec.axis = sec_axis(~./multiplier, name=""))+
theme_bw()+
theme(axis.text.x = element_text(angle = 0, hjust = 1))+
theme(plot.title = element_text(size = 20, face = "bold"),
axis.title.x = element_text(size = 18),
axis.title.y = element_text(size = 16),
legend.position = "none",
axis.text=element_text(size=18))
Now we can examine the Shapley values of observation 2’s X_2 predictor function, and we see the location of interest in X_2 just before the 0.50 point on the underlying t grid. This information about the important regions of predictor functions is invaluable for researchers seeking to understand why the lsBART model classifies an observation a certain way.